Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Saliency Map for GAT #435

Merged
merged 19 commits into from Aug 8, 2019
Merged

Saliency Map for GAT #435

merged 19 commits into from Aug 8, 2019

Conversation

sktzwhj
Copy link
Contributor

@sktzwhj sktzwhj commented Jun 28, 2019

Hi @adocherty and @youph ,

This PR adds the saliency map for GAT model. @adocherty had a look at the previous implementation and this one adapts the implementation to the new generator APIs.

-Huijun

@sktzwhj sktzwhj added the ml label Jun 28, 2019
@sktzwhj sktzwhj added this to the v0.4 Sprint 7 milestone Jun 28, 2019
@sktzwhj sktzwhj requested review from adocherty and youph June 28, 2019 04:04
@sktzwhj sktzwhj self-assigned this Jun 28, 2019
@review-notebook-app
Copy link

Check out this pull request on ReviewNB: https://app.reviewnb.com/stellargraph/stellargraph/pull/435

You'll be able to see visual diffs and write comments on notebook cells. Powered by ReviewNB.

This is typically the logit or softmax output.
"""

def __init__(self, model, generator):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function __init__ has a Cognitive Complexity of 8 (exceeds 5 allowed). Consider refactoring.

)
return np.squeeze(total_gradients * X_diff, 0)

def get_integrated_link_masks(
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function get_integrated_link_masks has a Cognitive Complexity of 6 (exceeds 5 allowed). Consider refactoring.

A_val = self.A
# Execute the function to compute the gradient
self.set_ig_values(1.0, 0.0)
if self.is_sparse and not sp.issparse(A_val):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Identical blocks of code found in 2 locations. Consider refactoring.

A_val = self.A
# Execute the function to compute the gradient
self.set_ig_values(alpha, non_exist_edge)
if self.is_sparse and not sp.issparse(A_val):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Identical blocks of code found in 2 locations. Consider refactoring.

This is typically the logit or softmax output.
"""

def __init__(self, model, generator):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cyclomatic complexity is too high in method init. (6)

@codeclimate
Copy link

codeclimate bot commented Jun 28, 2019

Code Climate has analyzed commit 1dbe778 and detected 8 issues on this pull request.

Here's the issue category breakdown:

Category Count
Complexity 5
Duplication 2
Security 1

View more on Code Climate.

@sktzwhj sktzwhj closed this Jul 1, 2019
@sktzwhj sktzwhj reopened this Jul 1, 2019
Copy link
Contributor

@adocherty adocherty left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Notebook

stellargraph_dev/demos/interpretability/gat/node-link-importance-demo-gat.ipynb

  • Jupyter notebook does not have title.
  • When you introduce importances and saliency maps you don’t seem to describe the assumptions of the functions. I believe the integrated gradient methods assume that the features are all binary. We should state this assumption clearly in the description. Also indicate what we should do if the features are not binary.
  • Sanity checks and others are best moved from notebooks to unit tests:
    • delta & non_exist_edge variable check
    • serialisation check
    • ego-graph integrate_link_mask check
    • masked_array check
  • Let’s do the same here as we discussed for GCN - I think the functions should all take a node ID not an index so we don’t confuse the different indices for graph vs pandas features.
  • How about changing the name of get_integrated_link_masks to get_link_importance to match the get_node_importance function.
  • As you have used sorted_indices[::-1] in the preceding cell, in the line:
print('Top {} most important links by integrated gradients are {}'.format(topk, integrated_link_importance_rank[-topk:]))

Shouldn’t integrated_link_importance_rank[-topk:] be integrated_link_importance_rank[:topk]?

  • The nodes in G_ego are by IDs, therefore I think we should use target_nid not target_idx here:
  if nid == target_nid:
    continue
  • In the visualisation code and elsewhere there are many instances of list(G.nodes()).index(·) that can be replaced with graph_nodes.index(·)

layer/graph_attention.py

  • I think we need to subtract the row maximum from dense before the exponential to avoid floating point errors in the exp function:
W = (
    (1 - self.non_exist_edge) * self.delta * A
    + self.non_exist_edge
    * (
        A
        + self.delta * (K.ones(shape=[N, N], dtype=*”float”*) - A)
        + K.eye(N)
    )
) * K.exp(dense - K.max(dense, axis=1, keepdims=True))
dense = W / K.sum(W, axis=1, keepdims=True)

This means the current tests fail, but I don’t believe this is bad – the results from this are different from the implementation without the subtraction. The results in the notebook are the same with this normalisation.

There is, as you point out, an issue with the number of non-zero elements in the link importance calculation; however, I think this is just an issue of floating point accuracy, and counting elements above a small threshold works fine: i.e. using the following

print("Number of non-zero elements in integrate_link_importance: {}".format(np.sum(np.abs(integrate_link_importance) > 1e-8)))

gives:

Number of edges in the ego graph: 210
Number of non-zero elements in integrate_link_importance: 210

I would like to add a unit test to the tests/layer/test_graph_attention.py file that checks that the results from the implementation with saliency_map_support=False are the same as those for saliency_map_support=True – such as the test_apply_average_with_neighbours method. Would you like to add this test?

utils/saliency_maps_gat

  • I think we should move these files to utils/saliency_maps and name them integrated_gradients_gat.py and saliency_gat.py . Additionally rename the IntegratedGradients class to IntegratedGradientsGAT and GradientSaliency to GradientSaliencyGAT. This way we can import all saliency objects to the same namespace in utils/saliency_maps/__init__.py.
  • As in the comments on the notebook above, let’s use the node IDs instead of index in all functions.
  • In get_integrated_node_masks the features are taken from zero to one.
    • This seems to assume that the features are all binary. What happens if they are not? We should state this assumption clearly in the class documentation.
    • This seems to only consider the importance of features that are one, as any features that are zero will be zero for all steps. I would have guessed that there would also be a function that takes those features from 1 to 0, you talked about this in the paper but I forget what you said now!
  • There should be a description of what each method does and the assumptions made.
  • In the argument list try not to put multiple variables on the same line, rather describe each variable separately on its own line.

@sktzwhj
Copy link
Contributor Author

sktzwhj commented Jul 31, 2019

Notebook

stellargraph_dev/demos/interpretability/gat/node-link-importance-demo-gat.ipynb

  • Jupyter notebook does not have title.
    I have added a title - Interpreting Nodes and Edges by Saliency Maps in GAT
  • When you introduce importances and saliency maps you don’t seem to describe the assumptions of the functions. I believe the integrated gradient methods assume that the features are all binary. We should state this assumption clearly in the description. Also indicate what we should do if the features are not binary.

IG does seem to work well for binary features compared with vanilla methods. However, it does not assume binary features. In fact, it was initially used in the image domain where features are not binary.

  • Sanity checks and others are best moved from notebooks to unit tests:

    • delta & non_exist_edge variable check
    • serialisation check
    • ego-graph integrate_link_mask check
    • masked_array check

Fixed. These sanity checks are now in the unit tests.

  • Let’s do the same here as we discussed for GCN - I think the functions should all take a node ID not an index so we don’t confuse the different indices for graph vs pandas features.
  • How about changing the name of get_integrated_link_masks to get_link_importance to match the get_node_importance function.
    Fixed.
  • As you have used sorted_indices[::-1] in the preceding cell, in the line:
print('Top {} most important links by integrated gradients are {}'.format(topk, integrated_link_importance_rank[-topk:]))

Shouldn’t integrated_link_importance_rank[-topk:] be integrated_link_importance_rank[:topk]?

Good catch! Fixed.

  • The nodes in G_ego are by IDs, therefore I think we should use target_nid not target_idx here:
  if nid == target_nid:
    continue
  • In the visualisation code and elsewhere there are many instances of list(G.nodes()).index(·) that can be replaced with graph_nodes.index(·)

Fixed.

layer/graph_attention.py

  • I think we need to subtract the row maximum from dense before the exponential to avoid floating point errors in the exp function:
W = (
    (1 - self.non_exist_edge) * self.delta * A
    + self.non_exist_edge
    * (
        A
        + self.delta * (K.ones(shape=[N, N], dtype=*”float”*) - A)
        + K.eye(N)
    )
) * K.exp(dense - K.max(dense, axis=1, keepdims=True))
dense = W / K.sum(W, axis=1, keepdims=True)

This means the current tests fail, but I don’t believe this is bad – the results from this are different from the implementation without the subtraction. The results in the notebook are the same with this normalisation.

There is, as you point out, an issue with the number of non-zero elements in the link importance calculation; however, I think this is just an issue of floating point accuracy, and counting elements above a small threshold works fine: i.e. using the following

print("Number of non-zero elements in integrate_link_importance: {}".format(np.sum(np.abs(integrate_link_importance) > 1e-8)))

gives:

Number of edges in the ego graph: 210
Number of non-zero elements in integrate_link_importance: 210

That's interesting. It does explain the previous test failure. Fixed in the tests.

I would like to add a unit test to the tests/layer/test_graph_attention.py file that checks that the results from the implementation with saliency_map_support=False are the same as those for saliency_map_support=True – such as the test_apply_average_with_neighbours method. Would you like to add this test?

I have changed the test to add the GAT model with saliency map support as well.

utils/saliency_maps_gat

  • I think we should move these files to utils/saliency_maps and name them integrated_gradients_gat.py and saliency_gat.py . Additionally rename the IntegratedGradients class to IntegratedGradientsGAT and GradientSaliency to GradientSaliencyGAT. This way we can import all saliency objects to the same namespace in utils/saliency_maps/__init__.py.

Fixed.

  • As in the comments on the notebook above, let’s use the node IDs instead of index in all functions.

Yes, fixed.

  • In get_integrated_node_masks the features are taken from zero to one.

It's actually from baseline (does not necessarily be 0) to the current state of X (which does not necessarily be 1).

  • This seems to assume that the features are all binary. What happens if they are not? We should state this assumption clearly in the class documentation.

Therefore, we do not assume binary features.

  • This seems to only consider the importance of features that are one, as any features that are zero will be zero for all steps. I would have guessed that there would also be a function that takes those features from 1 to 0, you talked about this in the paper but I forget what you said now!

Somehow I did not implement that in GAT. Fixed now.

  • There should be a description of what each method does and the assumptions made.

Fixed.

  • In the argument list try not to put multiple variables on the same line, rather describe each variable separately on its own line.

Fixed.

Copy link
Contributor

@adocherty adocherty left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to only consider the importance of features that are one, as any features that are zero will be zero for all steps. I would have guessed that there would also be a function that takes those features from 1 to 0, you talked about this in the paper but I forget what you said now!

Somehow I did not implement that in GAT. Fixed now.

You have introduced a flag which selects if the we should set the baseline to zero and go to X_val for all features, or set the baseline to X_val and go to one for all features. I was thinking that, for binary features at least, we should set the baseline to 1-X_val then the IG can calculate the change from 1 to 0 or 0 to 1, i.e. what will happen when the feature is different. Thus when we calculate node importance, we don't have to do so twice (once for features which are 1 and then again for features which are 0). What do you think?

Returns:
gradients (Numpy array): Returns a vanilla gradient mask for the nodes.
"""
out_indices = np.array([[node_idx]])
out_indices = np.array([[node_id]])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking we would look up the node index given the node IDs here? For example, the FullBatchNodeGenerator does this using the graph node_list :

node_indices = np.array([self.node_list.index(n) for n in node_ids])

We should do something similar here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should also do this for the original GCN saliency methods. I've added an issue for this: #466

X_diff = X_val - X_baseline
total_gradients = np.zeros(X_val.shape)

for alpha in np.linspace(1.0 / steps, 1, steps):
X_step = X_baseline + alpha * X_diff
total_gradients += super(IntegratedGradients, self).get_node_masks(
node_idx, class_of_interest, X_val=X_step
total_gradients += super(IntegratedGradientsGAT, self).get_node_masks(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be:

total_gradients += super().get_node_masks(
                node_id, class_of_interest, X_val=X_step)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

@@ -94,8 +106,8 @@ def get_integrated_link_masks(
for alpha in np.linspace(1.0 / steps, 1.0, steps):
if self.is_sparse:
A_val = sp.lil_matrix(A_val)
tmp = super(IntegratedGradients, self).get_link_masks(
alpha, node_idx, class_of_interest, int(non_exist_edge)
tmp = super(IntegratedGradientsGAT, self).get_link_masks(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can just have super() here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

@adocherty
Copy link
Contributor

IG does seem to work well for binary features compared with vanilla methods. However, it does not assume binary features. In fact, it was initially used in the image domain where features are not binary.

OK, thanks for clarifying! We can mention this in the class docstring, particularly the importance of setting X_baseline appropriately.

@sktzwhj
Copy link
Contributor Author

sktzwhj commented Aug 6, 2019

You have introduced a flag which selects if the we should set the baseline to zero and go to X_val for all features, or set the baseline to X_val and go to one for all features. I was thinking that, for binary features at least, we should set the baseline to 1-X_val then the IG can calculate the change from 1 to 0 or 0 to 1, i.e. what will happen when the feature is different. Thus when we calculate node importance, we don't have to do so twice (once for features which are 1 and then again for features which are 0). What do you think?

I tend to keep as is. Although we aimed at binary features as the motivation initially, we should not make the implementation to be specifically for that. If the features are not binary, the path from 0 -> X_val and X_val -> 1 are different, thus leading to no space for the optimization described above. Also, under the existing frameworks, calculating the gradients for part of the matrix does not really bring much performance improvements I think.

@sktzwhj
Copy link
Contributor Author

sktzwhj commented Aug 6, 2019

We need to think about how to return the link importance values to align with the node id rather than node indices in the parameters of saliency maps. In other words, we should not let the users do the manual re-mapping by themselves.

@adocherty
Copy link
Contributor

We need to think about how to return the link importance values to align with the node id rather than node indices in the parameters of saliency maps. In other words, we should not let the users do the manual re-mapping by themselves.

That is a good point, I didn't think about that. I think we should add this as a future issue.

I'm happy with the changes!

@sktzwhj sktzwhj merged commit 3463008 into develop Aug 8, 2019
@youph youph deleted the gat-saliency-map branch January 2, 2020 00:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants